package cn.bugstack.openai.executor.model.chatglm;

import cn.bugstack.openai.executor.Executor;
import cn.bugstack.openai.executor.model.chatglm.config.ChatGLMConfig;
import cn.bugstack.openai.executor.model.chatglm.valobj.ChatGLMCompletionRequest;
import cn.bugstack.openai.executor.model.chatglm.valobj.ChatGLMCompletionResponse;
import cn.bugstack.openai.executor.model.chatglm.valobj.Model;
import cn.bugstack.openai.executor.parameter.*;
import cn.bugstack.openai.executor.result.ResultHandler;
import cn.bugstack.openai.session.Configuration;
import com.chatplus.application.common.util.PlusJsonUtils;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.NotNull;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * ChatGLM 模型执行器
 * <p>
 * 文档：https://open.bigmodel.cn/dev/api#chatglm_turbo
 * ApiKey：https://open.bigmodel.cn/usercenter/apikeys
 *
 * @author 小傅哥，微信：fustack
 */
@Slf4j
public class ChatGLMModelExecutor implements Executor, ParameterHandler<ChatGLMCompletionRequest>, ResultHandler {

    /**
     * 配置信息
     */
    private final ChatGLMConfig chatGLMConfig;
    /**
     * 工厂事件
     */
    private final EventSource.Factory factory;

    public ChatGLMModelExecutor(Configuration configuration) {
        this.chatGLMConfig = configuration.getChatGLMConfig();
        this.factory = configuration.createRequestFactory();
    }

    @Override
    public EventSource completions(CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        return completions(null, null, completionRequest, eventSourceListener);
    }

    @Override
    public EventSource completions(String apiHostByUser, String apiKeyByUser, CompletionRequest completionRequest, EventSourceListener eventSourceListener) throws Exception {
        // 1. 转换参数信息
        ChatGLMCompletionRequest chatGLMCompletionRequest = getParameterObject(completionRequest);
        // 3. 构建请求信息
        Request request = new Request.Builder()
                .tag(completionRequest.getTag())
                .url(ChatGLMConfig.API_HOST.concat(ChatGLMConfig.V4_COMPLETIONS))
                .post(RequestBody.create(chatGLMCompletionRequest.toString(), MediaType.parse(Configuration.JSON_CONTENT_TYPE)))
                .build();
        // 4. 返回事件结果
        return factory.newEventSource(request, eventSourceListener(eventSourceListener));
    }

    @Override
    public ImageResponse genImages(ImageRequest imageRequest) {
        return null;
    }

    @Override
    public ImageResponse genImages(String apiHostByUser, String apiKeyByUser, ImageRequest imageRequest) {
        return null;
    }

    @Override
    public EventSource pictureUnderstanding(PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    @Override
    public EventSource pictureUnderstanding(String apiHostByUser, String apiKeyByUser, PictureRequest pictureRequest, EventSourceListener eventSourceListener) throws Exception {
        return null;
    }

    @Override
    public ChatGLMCompletionRequest getParameterObject(CompletionRequest completionRequest) {

        ChatGLMCompletionRequest chatGLMCompletionRequest = new ChatGLMCompletionRequest();
        chatGLMCompletionRequest.setModel(Model.getModel(completionRequest.getModel()));
        chatGLMCompletionRequest.setStream(completionRequest.isStream());
//        chatGLMCompletionRequest.setTemperature(completionRequest.getTemperature());
//        chatGLMCompletionRequest.setTopP(completionRequest.getTopP());
        List<ChatGLMCompletionRequest.Prompt> prompts = new ArrayList<>();

        // 重新组装参数，ChatGLM 需要用 Okay 间隔历史消息
        List<Message> messages = completionRequest.getMessages();
        for (int i = 0; i < messages.size(); i++) {
            Message message = messages.get(i);
            if (0 == i) {
                prompts.add(ChatGLMCompletionRequest.Prompt.builder()
                        .role(message.getRole())
                        .content(message.getContent())
                        .build());
            } else {
                String role = message.getRole();
                if (Objects.equals(role, CompletionRequest.Role.SYSTEM.getCode())) {
                    prompts.add(ChatGLMCompletionRequest.Prompt.builder()
                            .role(CompletionRequest.Role.SYSTEM.getCode())
                            .content(message.getContent())
                            .build());
//                    prompts.add(ChatGLMCompletionRequest.Prompt.builder()
//                            .role(CompletionRequest.Role.USER.getCode())
//                            .content("Okay")
//                            .build());
                } else {
                    prompts.add(ChatGLMCompletionRequest.Prompt.builder()
                            .role(CompletionRequest.Role.USER.getCode())
                            .content(message.getContent())
                            .build());
//                    prompts.add(ChatGLMCompletionRequest.Prompt.builder()
//                            .role(CompletionRequest.Role.USER.getCode())
//                            .content("Okay")
//                            .build());
                }
            }
        }

        chatGLMCompletionRequest.setMessages(prompts);

        return chatGLMCompletionRequest;
    }

    @Override
    public boolean supportFunction(String model) {
        return false;
    }

    @Override
    public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
        return new EventSourceListener() {
            @Override
            public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
                if ("[DONE]".equals(data)) {
                    return;
                }
                ChatGLMCompletionResponse response = PlusJsonUtils.parseObject(data, ChatGLMCompletionResponse.class);
                if (null == response) {
                    return;
                }
                // 构建结果
                CompletionResponse completionResponse = new CompletionResponse();
                completionResponse.setChoices(response.getChoices());
                completionResponse.setUsage(response.getUsage());
                completionResponse.setCreated(System.currentTimeMillis());
                // 返回数据
                eventSourceListener.onEvent(eventSource, id, type, PlusJsonUtils.toJsonString(completionResponse));
            }

            @Override
            public void onClosed(@NotNull EventSource eventSource) {
                eventSourceListener.onClosed(eventSource);
            }

            @Override
            public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
                eventSourceListener.onOpen(eventSource, response);
            }

            @Override
            public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
                eventSourceListener.onFailure(eventSource, t, response);
            }

        };
    }

}
